{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Multiple Adjacency Spectral Embedding (MASE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Multiple Adjacency Spectral Embedding (MASE) is an extension of Adjacency Spectral Embedding (see ASE [tutorial](https://graspy.neurodata.io/tutorials/embedding/adjacencyspectralembed)) for an arbitrary number of graphs.  Once graphs are embedded, the low-dimensional Euclidean representation can be used to visualize the latent positions of vertices, perform inference, etc."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Replicating Omnibus Embedding Tutorial"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To demonstrate how to use the ``MultipleASE`` class, we will use a simple example: two 2-block stochastic block models (SBMs) with different block probabilities.  We use the same models as in the omnibus [tutorial](https://graspy.neurodata.io/tutorials/embedding/omnibus) to facilitate direct comparison."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Simulate two different graphs using stochastic block models (SBM)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We sample two 2-block SBMs (undirected, no self-loops) with 50 vertices, each block containing 25 vertices ($n = [25, 25]$), and with the following block connectivity matrices:\n",
    "\n",
    "\\begin{align*}\n",
    "B_1 = \n",
    "\\begin{bmatrix}0.3 & 0.1\\\\\n",
    "0.1 & 0.7\n",
    "\\end{bmatrix},~\n",
    "B_2 = \\begin{bmatrix}0.3 & 0.1\\\\\n",
    "0.1 & 0.3\n",
    "\\end{bmatrix}\n",
    "\\end{align*}\n",
    "\n",
    "The only difference between the two is the within-block probability for the second block. We sample $G_1 \\sim \\text{SBM}(n, B_1)$ and $G_2 \\sim \\text{SBM}(n, B_2)$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from graspologic.simulations import sbm\n",
    "\n",
    "n = [25, 25]\n",
    "B1 = [[0.3, 0.1],\n",
    "      [0.1, 0.7]]\n",
    "B2 = [[0.3, 0.1],\n",
    "      [0.1, 0.3]]\n",
    "\n",
    "np.random.seed(8)\n",
    "G1 = sbm(n, B1)\n",
    "G2 = sbm(n, B2)"
   ]
  },
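  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For intuition about what this sampling step involves, here is a minimal NumPy-only sketch of drawing an undirected SBM graph without self-loops. The helper `sample_sbm_sketch` is hypothetical and written only for illustration; it is not graspologic's implementation of `sbm`, and its random draws should not be expected to match those of `sbm`, even with the same seed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def sample_sbm_sketch(block_sizes, B, rng):\n",
    "    # Hypothetical helper: sample an undirected SBM adjacency matrix without self-loops\n",
    "    B = np.asarray(B)\n",
    "    # Block label of each vertex, e.g. 25 zeros followed by 25 ones\n",
    "    labels = np.repeat(np.arange(len(block_sizes)), block_sizes)\n",
    "    # Edge probability for every vertex pair, looked up from the block matrix\n",
    "    P = B[labels][:, labels]\n",
    "    # Draw each pair once (strict upper triangle), then mirror to make the matrix symmetric\n",
    "    upper = np.triu(rng.random(P.shape) < P, k=1)\n",
    "    return (upper | upper.T).astype(float)\n",
    "\n",
    "sketch_rng = np.random.default_rng(8)\n",
    "A_sketch = sample_sbm_sketch([25, 25], [[0.3, 0.1], [0.1, 0.7]], sketch_rng)\n",
    "\n",
    "# Sanity checks: square, symmetric, and no self-loops\n",
    "print(A_sketch.shape, np.allclose(A_sketch, A_sketch.T), np.trace(A_sketch))"
   ]
  },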
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualize the graphs using heatmap\n",
    "\n",
    "We visualize the sampled graphs using the ``heatmap()`` function. ``heatmap()`` will plot the adjacency matrix, where the colors represent the weight of each edge. In this case, we have binary graphs so the values will be either 0 or 1. \n",
    "\n",
    "We see that there is clear block structure to the graphs. Furthermore, the lower right quarter of $G_1$, representing the within-group connections for the second group, is more dense than that of $G_2$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from graspologic.plot import heatmap\n",
    "\n",
    "heatmap(G1, figsize=(7,7), title=\"Visualization of Graph 1\", cbar=False)\n",
    "_ = heatmap(G2, figsize=(7,7), title=\"Visualization of Graph 2\", cbar=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Embed the two graphs using MASE"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Just as ASE fits a single input graph to the Random Dot Product Graph (RDPG) model, MASE fits many input graphs to the common subspace independent-edge (COSIE) model, which extends the RDPG model to multiple graphs. The COSIE model assumes that the graphs share important properties but can also differ in meaningful ways.  Accordingly, COSIE decomposes the set of graphs into a set of shared latent positions $V$ that describe similarities, and a set of score matrices $R^{(i)}$ that describe how each individual graph differs.  Mathematically, for the adjacency matrix $A^{(i)}$ of an undirected graph, the model holds in expectation:\n",
    "$$\\mathbb{E}[A^{(i)}] = VR^{(i)}V^T$$"
   ]
  },
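  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the decomposition concrete for the two SBMs in this tutorial, the sketch below builds the edge-probability matrices $P^{(i)} = Z B_i Z^T$, where $Z$ is the $50 \\times 2$ block-membership indicator matrix (a symbol introduced here for illustration). Because each block has 25 vertices, $V = Z / 5$ has orthonormal columns, and $R^{(i)} = 25 B_i$ satisfies $P^{(i)} = V R^{(i)} V^T$. This is only one valid choice: $V$ is determined only up to an orthogonal transformation, so the estimates returned by MASE need not match these matrices entry for entry."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Block-membership indicator matrix Z: each row is one-hot for that vertex's block\n",
    "Z = np.zeros((50, 2))\n",
    "Z[:25, 0] = 1\n",
    "Z[25:, 1] = 1\n",
    "\n",
    "# Orthonormal basis for the shared subspace (each block has 25 vertices, sqrt(25) = 5)\n",
    "V_true = Z / 5\n",
    "\n",
    "for name, B in [(\"B1\", [[0.3, 0.1], [0.1, 0.7]]), (\"B2\", [[0.3, 0.1], [0.1, 0.3]])]:\n",
    "    B = np.array(B)\n",
    "    P = Z @ B @ Z.T  # edge probabilities (the diagonal is ignored when there are no self-loops)\n",
    "    R = 25 * B       # score matrix for this particular choice of V\n",
    "    print(name, \"V R V^T equals Z B Z^T:\", np.allclose(V_true @ R @ V_true.T, P))"
   ]
  },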
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In general, 2 clusters can be faithfully represented in 2 dimensions, so we set the number of components to 2. The ``fit_transform()`` method then returns the estimated latent positions $\\hat{V}$, a $50 \\times 2$ array with one row for each of the 50 vertices."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from graspologic.embed import MultipleASE as MASE\n",
    "\n",
    "embedder = MASE(n_components=2)\n",
    "V_hat = embedder.fit_transform([G1, G2])\n",
    "print(V_hat.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualize Common Latent Positions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since the two graphs have clear block structures, with higher within-block probabilities than between-block, we see two distinct \"clusters\" when we visualize the shared latent positions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "def vis_latent(V, title, predicted=False, pred_labels=None):\n",
    "    fig, ax = plt.subplots(figsize=(10,10))\n",
    "    if predicted:\n",
    "        # Color by predicted labels; index the V argument rather than the global V_hat\n",
    "        ax.scatter(V[pred_labels==0, 0], V[pred_labels==0, 1], c=\"blue\", label=\"Predicted Block 1\")\n",
    "        ax.scatter(V[pred_labels==1, 0], V[pred_labels==1, 1], c=\"red\", label=\"Predicted Block 2\")\n",
    "    else:\n",
    "        # Color by true block membership: the first 25 vertices are block 1, the rest block 2\n",
    "        ax.scatter(V[:25, 0], V[:25, 1], c=\"blue\", label=\"Block 1\")\n",
    "        ax.scatter(V[25:, 0], V[25:, 1], c=\"red\", label=\"Block 2\")\n",
    "\n",
    "    ax.legend(prop={'size':20})\n",
    "    ax.set_xticklabels([])\n",
    "    ax.set_yticklabels([])\n",
    "    ax.set_xticks([])\n",
    "    ax.set_yticks([])\n",
    "    _ = ax.set_title(title, fontsize=20)\n",
    "\n",
    "vis_latent(V_hat, \"Common Latent Positions from MASE Embedding\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Recover Original Communities"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can apply a simple clustering algorithm (kmeans) to recover the original communities."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Cluster V_hat\n",
    "from graspologic.cluster import KMeansCluster\n",
    "clusterer = KMeansCluster()\n",
    "pred_labels = clusterer.fit_predict(V_hat)\n",
    "\n",
    "#Remap labels\n",
    "from graspologic.utils import remap_labels\n",
    "true_labels = [0]*25 + [1]*25\n",
    "remapped_labels = remap_labels(true_labels, pred_labels)\n",
    "\n",
    "#Visualize latent positions\n",
    "vis_latent(V_hat, \"Predicted Communities from Common Latent Positions\", True, remapped_labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Even with only one graph sampled from each SBM, we can perfectly recover the two communities!"
   ]
  },
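  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The remapping step above is needed because cluster labels are arbitrary: k-means may call the first block `1` and the second block `0`. The sketch below illustrates the idea for the two-cluster case using plain NumPy. The helper `align_binary_labels` and the toy labels are hypothetical; this is not the implementation behind `remap_labels`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def align_binary_labels(y_true, y_pred):\n",
    "    # Hypothetical helper: flip 0/1 predictions if the flipped labels agree better with y_true\n",
    "    y_true = np.asarray(y_true)\n",
    "    y_pred = np.asarray(y_pred)\n",
    "    flipped = 1 - y_pred\n",
    "    if np.mean(flipped == y_true) > np.mean(y_pred == y_true):\n",
    "        return flipped\n",
    "    return y_pred\n",
    "\n",
    "toy_true = [0, 0, 0, 1, 1, 1]\n",
    "toy_pred = [1, 1, 0, 0, 0, 0]  # label names swapped, with one vertex misassigned\n",
    "print(align_binary_labels(toy_true, toy_pred))"
   ]
  },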
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sampling More Graphs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can recover a more accurate representation of the two SBM populations with more sample graphs.  Let's try with 100 graphs sampled from each."
   ]
  },
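  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample 100 graphs from each SBM; G1 and G2 are now lists of adjacency matrices\n",
    "G1 = [sbm(n, B1) for _ in range(100)]\n",
    "G2 = [sbm(n, B2) for _ in range(100)]\n",
    "\n",
    "# Embed all 200 graphs jointly\n",
    "V_hat = embedder.fit_transform(G1 + G2)"
   ]
  },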
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualizing Common Latent Subspace"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, the clustering is even more distinct, as we would expect with more samples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vis_latent(V_hat, \"Common Latent Subspace from MASE Embedding with Many Sampled Graphs\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualize Score Matrices"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that we have more samples, we can apply the same process to visualize the individual score matrices.  Note that since the score matrices are symmetric $2 \\times 2$ matrices, we only need a 3-dimensional vector to represent each.  In this case, we can clearly see two distinct clusters: one for the score matrices from the first SBM population and another for the score matrices from the second."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from graspologic.plot import pairplot\n",
    "\n",
    "# Extract score matrices (in vector form)\n",
    "R_1_hat = np.reshape(embedder.scores_[:100], (100, 4))\n",
    "R_2_hat = np.reshape(embedder.scores_[100:], (100, 4))\n",
    "\n",
    "# Omit the redundant off-diagonal entry (the matrices are symmetric)\n",
    "R_1_hat = np.delete(R_1_hat, 1, axis=1)\n",
    "R_2_hat = np.delete(R_2_hat, 1, axis=1)\n",
    "\n",
    "# Combine into a single array\n",
    "R_hat = np.concatenate((R_1_hat, R_2_hat), axis=0)\n",
    "\n",
    "# Visualize; a dict palette fixes the color of each population (a set has no defined order)\n",
    "labels = [\"SBM 1\"]*100 + [\"SBM 2\"]*100\n",
    "_ = pairplot(R_hat, labels, title=\"Score Matrices from MASE Embedding\",\n",
    "             palette={\"SBM 1\": \"green\", \"SBM 2\": \"purple\"}, legend_name=\"Population\")"
   ]
  },
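  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "An equivalent way to vectorize symmetric score matrices, which also extends to more than 2 components, is to index the upper triangle with `np.triu_indices`. The sketch below uses made-up symmetric matrices so that it runs on its own; `toy_scores` is illustrative data, not output from MASE."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "toy_rng = np.random.default_rng(0)\n",
    "# Made-up stack of 5 symmetric 2 x 2 matrices standing in for score matrices\n",
    "M = toy_rng.normal(size=(5, 2, 2))\n",
    "toy_scores = (M + M.transpose(0, 2, 1)) / 2\n",
    "\n",
    "# Row and column indices of the upper triangle, diagonal included\n",
    "rows, cols = np.triu_indices(2)\n",
    "toy_vectors = toy_scores[:, rows, cols]\n",
    "print(toy_vectors.shape)\n",
    "\n",
    "# Same information as reshaping to 4 columns and deleting one off-diagonal entry\n",
    "alt = np.delete(toy_scores.reshape(5, 4), 1, axis=1)\n",
    "print(np.allclose(toy_vectors, alt))"
   ]
  },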
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Recover Original SBM Membership"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As before, we can apply a simple clustering algorithm (kmeans) to recover which SBM population generated each graph."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cluster R_hat\n",
    "pred_labels = clusterer.fit_predict(R_hat)\n",
    "true_labels = [0]*100 + [1]*100\n",
    "remapped_labels = remap_labels(true_labels, pred_labels)\n",
    "\n",
    "# Translate numeric labels back into population names\n",
    "pred_sbm = [\"SBM 1\" if label == 0 else \"SBM 2\" for label in remapped_labels]\n",
    "\n",
    "# Visualize results; a dict palette fixes the color of each population (a set has no defined order)\n",
    "_ = pairplot(R_hat, pred_sbm, title=\"Predicted SBM Membership from Score Matrices\",\n",
    "             palette={\"SBM 1\": \"green\", \"SBM 2\": \"purple\"}, legend_name=\"Population\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With this method we can perfectly recover the original SBM population to which each graph belongs!"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}